{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Processing the data (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel; the quotes keep the\n",
    "# shell from treating the extras brackets as a glob pattern.\n",
    "%pip install datasets evaluate \"transformers[sentencepiece]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "\n",
    "# Same as before\n",
    "checkpoint = \"bert-base-uncased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
    "sequences = [\n",
    "    \"I've been waiting for a HuggingFace course my whole life.\",\n",
    "    \"This course is amazing!\",\n",
    "]\n",
    "batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
    "\n",
    "# This is new: one target label per sequence, so the model also returns a loss\n",
    "batch[\"labels\"] = torch.tensor([1, 1])\n",
    "\n",
    "# from_pretrained() hands the model back in evaluation mode (dropout disabled),\n",
    "# so switch to training mode before computing a training loss.\n",
    "model.train()\n",
    "\n",
    "optimizer = AdamW(model.parameters())\n",
    "loss = model(**batch).loss\n",
    "loss.backward()\n",
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
       "        num_rows: 3668\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
       "        num_rows: 408\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
       "        num_rows: 1725\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load the \"mrpc\" configuration of \"glue\"; displaying the result lists its splits\n",
    "raw_datasets = load_dataset(path=\"glue\", name=\"mrpc\")\n",
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'idx': 0,\n",
       " 'label': 1,\n",
       " 'sentence1': 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .',\n",
       " 'sentence2': 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep a handle on the training split and look at its first example\n",
    "raw_train_dataset = raw_datasets[\"train\"]\n",
    "raw_train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentence1': Value(dtype='string', id=None),\n",
       " 'sentence2': Value(dtype='string', id=None),\n",
       " 'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None),\n",
       " 'idx': Value(dtype='int32', id=None)}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column names of the training split mapped to their declared feature types\n",
    "raw_train_dataset.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "checkpoint = \"bert-base-uncased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "# Indexing a split by column name may return a lazy Column object rather than a\n",
    "# plain list (newer Datasets releases); list(...) gives the tokenizer the list of\n",
    "# strings it expects on every version.\n",
    "tokenized_sentences_1 = tokenizer(list(raw_datasets[\"train\"][\"sentence1\"]))\n",
    "tokenized_sentences_2 = tokenizer(list(raw_datasets[\"train\"][\"sentence2\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{ \n",
       "  'input_ids': [101, 2023, 2003, 1996, 2034, 6251, 1012, 102, 2023, 2003, 1996, 2117, 2028, 1012, 102],\n",
       "  'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],\n",
       "  'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
       "}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Passing a second string encodes the two sentences together as one pair\n",
    "inputs = tokenizer(\n",
    "    text=\"This is the first sentence.\",\n",
    "    text_pair=\"This is the second one.\",\n",
    ")\n",
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]', 'this', 'is', 'the', 'first', 'sentence', '.', '[SEP]', 'this', 'is', 'the', 'second', 'one', '.', '[SEP]']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map the ids back to tokens to see how the pair was laid out\n",
    "tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# list(...) turns each column into a plain list of strings: newer Datasets\n",
    "# releases may return a lazy Column object here, which the tokenizer can reject.\n",
    "tokenized_dataset = tokenizer(\n",
    "    list(raw_datasets[\"train\"][\"sentence1\"]),\n",
    "    list(raw_datasets[\"train\"][\"sentence2\"]),\n",
    "    padding=True,\n",
    "    truncation=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(example):\n",
    "    \"\"\"Tokenize the `sentence1`/`sentence2` pair(s) held in `example`.\n",
    "\n",
    "    Truncation is enabled; no padding is requested here. Uses the\n",
    "    notebook-level `tokenizer` created in an earlier cell.\n",
    "    \"\"\"\n",
    "    first_sentences = example[\"sentence1\"]\n",
    "    second_sentences = example[\"sentence2\"]\n",
    "    return tokenizer(first_sentences, second_sentences, truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence1', 'sentence2', 'token_type_ids'],\n",
       "        num_rows: 3668\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence1', 'sentence2', 'token_type_ids'],\n",
       "        num_rows: 408\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence1', 'sentence2', 'token_type_ids'],\n",
       "        num_rows: 1725\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply tokenize_function to every split; batched=True passes it several\n",
    "# examples per call instead of one at a time.\n",
    "tokenized_datasets = raw_datasets.map(function=tokenize_function, batched=True)\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "# The collator is handed the tokenizer used above to build padded batches\n",
    "data_collator = DataCollatorWithPadding(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[50, 59, 47, 67, 59, 50, 62, 32]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take the first 8 training examples and drop the columns that are not\n",
    "# passed on to the collator.\n",
    "dropped_columns = [\"idx\", \"sentence1\", \"sentence2\"]\n",
    "first_examples = tokenized_datasets[\"train\"][:8]\n",
    "samples = {\n",
    "    column: values\n",
    "    for column, values in first_examples.items()\n",
    "    if column not in dropped_columns\n",
    "}\n",
    "# Length of each example before any padding is applied\n",
    "[len(ids) for ids in samples[\"input_ids\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': torch.Size([8, 67]),\n",
       " 'input_ids': torch.Size([8, 67]),\n",
       " 'token_type_ids': torch.Size([8, 67]),\n",
       " 'labels': torch.Size([8])}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collate the samples into a single batch and show the shape of every entry\n",
    "batch = data_collator(samples)\n",
    "{name: tensor.shape for name, tensor in batch.items()}"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Processing the data (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
